# plot.py
"""
Contains functions for plotting simulation results.
"""
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
from collections import defaultdict
import random # For report generation plot
import config # For config parameters like NUM_VALUE_BINS, VALUE_RANGE etc.
from utils import discretize, get_report_from_action # For report generation plot
from scipy.stats import linregress # For linear regression

def _cycled(palette, count):
    """Return `count` entries taken from `palette`, repeating it as needed."""
    return [palette[k % len(palette)] for k in range(count)]


def _moving_average(values, window_size):
    """Box-filter `values` and report whether smoothing was applied.

    Returns:
        tuple: (series, smoothed). When `min(window_size, len(values))` is
        positive the series is the 'valid' moving average and `smoothed` is
        True; otherwise the raw values are returned as an array and
        `smoothed` is False.
    """
    series = np.asarray(values, dtype=float)
    window = min(window_size, len(series))
    if window <= 0:
        return series, False
    return np.convolve(series, np.ones(window) / window, mode='valid'), True


def _trailing_mean_mu(mu_history, min_length):
    """Average the last max(10, 10%) entries of `mu_history`.

    Only entries that are numpy arrays whose first dimension is larger than
    `min_length` are averaged. Falls back to a zero vector of size
    config.COST_DIM when there is nothing usable.
    """
    fallback = np.zeros(config.COST_DIM)
    # Explicit None/len checks: truth-testing a numpy array would raise.
    if mu_history is None or len(mu_history) == 0:
        return fallback
    tail_length = max(10, len(mu_history) // 10)
    usable = [
        mu for mu in mu_history[-tail_length:]
        if isinstance(mu, np.ndarray) and mu.ndim > 0 and mu.shape[0] > min_length
    ]
    return np.mean(usable, axis=0) if usable else fallback


def _sample_agent_reports(agent, agent_index, eval_mu, num_rounds):
    """Query `agent` on random true values and group its reports by value bin.

    The agent's `epsilon` is set to 0 while sampling (presumably to disable
    exploration -- confirm against the agent class) and restored afterwards
    so the caller's agent object is left as it was.

    Returns:
        defaultdict(list): discretized true-value bin -> list of reports.
    """
    reports_by_bin = defaultdict(list)
    saved_epsilon = getattr(agent, 'epsilon', 0)
    agent.epsilon = 0
    try:
        for _ in range(num_rounds):
            true_value = random.uniform(config.MIN_VALUE_RANGE[agent_index],
                                        config.MAX_VALUE_RANGE[agent_index])
            state = agent.get_state(true_value, random.randint(0, config.T - 1), eval_mu)
            action_idx = agent.choose_action(state)
            report = get_report_from_action(action_idx, agent.num_report_actions, agent.report_range)
            value_bin = discretize(true_value, config.NUM_VALUE_BINS, config.VALUE_RANGE)
            reports_by_bin[value_bin].append(report)
    finally:
        agent.epsilon = saved_epsilon
    return reports_by_bin


def _plot_normalized_welfare(names, results_list, results_map, benchmark_name,
                             colors, linestyles, window_size):
    """Draw each scenario's welfare as a percentage of the benchmark's welfare."""
    benchmark = results_map.get(benchmark_name)
    plotted_any = False
    for i, (name, results) in enumerate(zip(names, results_list)):
        if name == benchmark_name:
            continue
        if benchmark is None:
            raise KeyError(f'truthful benchmark {benchmark_name!r} not found in results_map')
        welfare_hist = results[0]
        if len(welfare_hist) == 0:
            continue
        normalized = np.divide(welfare_hist, benchmark[0]) * 100
        series, smoothed = _moving_average(normalized, window_size)
        label = name if smoothed else f'{name} (Raw)'
        plt.plot(series, label=label, color=colors[i], linestyle=linestyles[i])
        plotted_any = True
    plt.xlabel('Trial')
    plt.ylabel('Normalized Social Welfare')
    plt.title('Social Welfare Normalized by Offline Optimum')
    plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter())
    if plotted_any:  # legend() warns when there is nothing to list
        plt.legend(loc='lower right')
    plt.grid(True)


def _plot_agent_reports(names, results_list, colors, agent_index, min_mu_length,
                        num_eval_rounds):
    """Draw grouped boxplots of one agent's reports per true-value bin.

    Only scenarios whose name contains "Q-Learn" are considered. Each value
    bin forms a group on the x axis; within a group there is one box per
    Q-learning scenario, coloured like that scenario.
    """
    q_indices = [idx for idx, name in enumerate(names) if "Q-Learn" in name]
    num_q = len(q_indices)

    # First pass: sample every scenario, so the x axis can be built from the
    # union of observed bins rather than from the first scenario only.
    sampled = []  # (slot inside each group, box colour, reports per bin)
    for slot, idx in enumerate(q_indices):
        name = names[idx]
        results = results_list[idx]
        agents = results[6]  # agents are the 7th element of the result tuple
        if agents is None or len(agents) <= agent_index:
            continue
        eval_mu = np.zeros(config.COST_DIM)
        # NOTE(review): this substring test does not match names spelled
        # "O-FTRL"; such scenarios are evaluated with mu = 0. Confirm intended.
        if "Alg1" in name or "OFTRL" in name:
            eval_mu = _trailing_mean_mu(results[4], min_mu_length)
        reports = _sample_agent_reports(agents[agent_index], agent_index, eval_mu, num_eval_rounds)
        if reports:
            sampled.append((slot, colors[idx], reports))

    if not sampled:
        plt.text(0.5, 0.5, 'No Q-learning agent report data', ha='center', va='center')
        return

    all_bins = sorted(set().union(*(reports.keys() for _, _, reports in sampled)))
    group_stride = num_q + 1  # leaves a gap between neighbouring bin groups
    centers = {value_bin: j * group_stride for j, value_bin in enumerate(all_bins)}
    box_offset = 0.8

    for slot, color, reports in sampled:
        bins_seen = sorted(reports)
        # Boxes of one group are spread symmetrically around the group centre.
        shift = box_offset * (slot - (num_q - 1) / 2)
        positions = [centers[b] + shift for b in bins_seen]
        bp = plt.boxplot([reports[b] for b in bins_seen], positions=positions,
                         sym='', widths=0.6, patch_artist=True)
        for patch in bp['boxes']:
            patch.set_facecolor(color)

    # Ticks and the truthful line sit on the group centres, i.e. in the
    # middle of each bin's boxes.
    tick_positions = [centers[b] for b in all_bins]
    # Labels and bin centres divide by NUM_VALUE_BINS, which assumes values
    # span [0, 1] -- TODO confirm against config.VALUE_RANGE.
    bin_labels = [f'[{b/config.NUM_VALUE_BINS:.1f},{(b+1)/config.NUM_VALUE_BINS:.1f}]' for b in all_bins]
    plt.xticks(tick_positions, bin_labels, rotation=45, ha='right')

    plt.xlabel('True Value Bin ($v_{t,i}$)')
    plt.ylabel('Reported Value ($u_{t,i}$)')
    plt.title('Agent\'s Reports vs True Value')

    bin_centers_y = [(b + 0.5) / config.NUM_VALUE_BINS for b in all_bins]
    plt.plot(tick_positions, bin_centers_y, 'k--', label='Truthful ($u=v$)')

    legend_handles = [plt.Line2D([0], [0], color='black', linestyle='--', lw=2)]
    legend_labels = ['Truthful ($u=v$)']
    for idx in q_indices:
        legend_handles.append(plt.Line2D([0], [0], color=colors[idx], lw=4))
        legend_labels.append(names[idx])
    plt.legend(legend_handles, legend_labels, loc='lower right')
    plt.grid(True, axis='y')


def _plot_fixed_point_error(names, results_list, colors, linestyles, window_size):
    """Draw the FTRL and fixed-point relative errors of the O-FTRL-FP scenario."""
    plotted_any = False
    for i, (name, results) in enumerate(zip(names, results_list)):
        if name != 'Q-Learning vs our O-FTRL-FP':
            continue
        # The error history is an optional 8th element of the result tuple.
        error_pairs = results[7] if len(results) > 7 else None
        if error_pairs is None or len(error_pairs) == 0:
            continue
        ftrl_series, smoothed = _moving_average([pair[0] for pair in error_pairs], window_size)
        fixed_point_series, _ = _moving_average([pair[1] for pair in error_pairs], window_size)
        if smoothed:
            ftrl_label, fixed_point_label = 'FTRL error', 'Fixed point error'
        else:
            ftrl_label = 'Relative Error with FTRL'
            fixed_point_label = 'Relative Error with Fixed Point'
        # Both series are arrays here, so "* 100" scales to percent (on a
        # plain list it would repeat the list instead).
        plt.plot(ftrl_series * 100, label=ftrl_label, color=colors[i], linestyle=linestyles[i])
        # The second curve borrows the previous scenario's style (index -1
        # wraps to the last scenario when i == 0).
        plt.plot(fixed_point_series * 100, label=fixed_point_label,
                 color=colors[i - 1], linestyle=linestyles[i - 1])
        plotted_any = True
    plt.xlabel('Trial')
    plt.ylabel('Relative Error in L2 Norm')
    plt.title('Relative Error due to Fixed Point Approximation')
    plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter())
    if plotted_any:  # legend() warns when there is nothing to list
        plt.legend(loc='lower right')
    plt.grid(True)


def plot_multiple_results(results_map, truthful_benchmark_name, *,
                          output_path='plot.png', show=True):
    """
    Generates plots comparing the results of multiple scenarios.

    The figure has one row of three panels:
      1. welfare of every non-benchmark scenario as a percentage of the
         truthful benchmark's welfare (moving average over
         config.PLOT_WINDOW_SIZE trials),
      2. boxplots of agent 0's reports per true-value bin for every scenario
         whose name contains "Q-Learn",
      3. FTRL / fixed-point relative errors for the scenario named
         'Q-Learning vs our O-FTRL-FP'.

    Args:
        results_map (dict): A dictionary where keys are scenario names
                            and values are the simulation result tuples.
                            result tuple: (welfare, agent_util, cost, B_final, mu_final, no_alloc, agents)
                            optionally followed by an 8th element, a sequence
                            of (ftrl_error, fixed_point_error) pairs used by
                            the third panel.
        truthful_benchmark_name (str): Key for the truthful baseline.
        output_path (str): File the figure is saved to.
        show (bool): Whether to call plt.show() after saving.

    Raises:
        KeyError: If a non-benchmark scenario exists but
                  `truthful_benchmark_name` is not a key of `results_map`.
    """
    base_colors = ['blue', 'green', 'red', 'purple', 'orange', 'brown', 'pink', 'gray']
    base_linestyles = ['-', '--', '-.', ':', '-', '--', '-.', ':']

    names_list = list(results_map.keys())
    results_list = list(results_map.values())
    num_scenarios = len(names_list)
    colors = _cycled(base_colors, num_scenarios)
    linestyles = _cycled(base_linestyles, num_scenarios)

    plt.rcParams.update({'font.size': 12})
    plt.figure(figsize=(16, 5.5))  # one row, three columns

    agent_to_plot = 0
    resource_dim_to_plot = 0
    num_eval_rounds = 1000
    window_size = config.PLOT_WINDOW_SIZE

    # Panel 1: total welfare relative to the truthful benchmark
    plt.subplot(1, 3, 1)
    _plot_normalized_welfare(names_list, results_list, results_map, truthful_benchmark_name,
                             colors, linestyles, window_size)

    # Panel 2: agent reports (only for Q-learning scenarios)
    plt.subplot(1, 3, 2)
    _plot_agent_reports(names_list, results_list, colors, agent_to_plot,
                        resource_dim_to_plot, num_eval_rounds)

    # NOTE: a former average-cost panel (based on results[2] and config.RHO)
    # is intentionally not drawn; the third slot shows the error panel.

    # Panel 3: FTRL / fixed-point error
    plt.subplot(1, 3, 3)
    _plot_fixed_point_error(names_list, results_list, colors, linestyles, window_size)

    plt.tight_layout()
    plt.savefig(output_path)
    if show:
        plt.show()
